import jax
from learned_optimization.tasks.datasets import image
from learned_optimization.learned_optimizers import mlp_lopt
from learned_optimization.tasks.fixed import image_mlp
from tasks import *
from dataset import *
from algorithms import *
import logging
import os 


def parse_lopt(lopt=None, config=None):
  """Construct the learned optimizer identified by ``lopt``.

  Args:
    lopt: Name of the learned optimizer. Only ``'meta6'`` is supported.
    config: Configuration object forwarded unchanged to the optimizer.

  Returns:
    A ``MetaSGHMC6`` instance built with ``config``.

  Raises:
    NotImplementedError: If ``lopt`` is not a recognised optimizer name.
  """
  if lopt == 'meta6':
    return MetaSGHMC6(config=config)
  # The previous message referred to a "dataset", which was misleading for an
  # optimizer lookup; include the rejected value to ease debugging.
  raise NotImplementedError(f'Inappropriate learned optimizer name: {lopt!r}')


def parse_eval_task(task, batch_size, config=None):
  """Build the evaluation task named ``task``.

  Args:
    task: Name of the evaluation task. One of ``'mnist'``, ``'emnist'``,
      ``'fmnist'``, ``'fmnist_conv'``, ``'ResNet'`` (built on CIFAR-10),
      ``'cifar100'`` or ``'tiny64_resnet56'``.
    batch_size: Batch size passed to the dataset constructor.
    config: Configuration object. ``config.eval.wd`` is read and used as the
      task's weight decay. Although the parameter defaults to ``None``, a
      config is required: ``None`` fails when ``config.eval`` is accessed.

  Returns:
    The constructed task object.

  Raises:
    NotImplementedError: If ``task`` is not a recognised task name.
  """
  hidden_size = [20, 20]
  weight_decay = config.eval.wd
  # Build into a separate local so ``task`` keeps holding the requested name.
  if task == 'mnist':
    datasets = image.mnist_datasets(batch_size=batch_size)
    return _MLPImageTask(datasets=datasets, hidden_sizes=hidden_size,
                         weight_decay=weight_decay)
  if task == 'emnist':
    datasets = emnist_datasets(batch_size=batch_size)
    return _MLPImageTask(datasets=datasets, hidden_sizes=[256, 128],
                         weight_decay=weight_decay)
  if task == 'fmnist':
    datasets = image.fashion_mnist_datasets(batch_size=batch_size)
    return _MLPImageTask(datasets=datasets, hidden_sizes=hidden_size,
                         weight_decay=weight_decay)
  if task == 'fmnist_conv':
    datasets = image.fashion_mnist_datasets(batch_size=batch_size)
    base_model_fn = _new_cross_entropy_pool_loss([32], jax.nn.relu,
                                                 num_classes=10)
    return _ConvTask(base_model_fn, datasets, weight_decay)
  if task == 'ResNet':
    # CIFAR-10 with per-channel normalisation.
    datasets = image.cifar10_datasets(batch_size=batch_size,
                                      normalize_mean=(0.49, 0.48, 0.44),
                                      normalize_std=(0.2, 0.2, 0.2))
    base_model_fn = _fc_resnet_loss_fn_tmp(
        num_classes=datasets.extra_info['num_classes'])
    return _ResNetTask(base_model_fn=base_model_fn, datasets=datasets,
                       weight_decay=weight_decay)
  if task == 'cifar100':
    datasets = image.cifar100_datasets(batch_size=batch_size)
    base_model_fn = _fc_resnet_loss_fn_tmp(
        num_classes=datasets.extra_info['num_classes'])
    return _ResNetTask(base_model_fn=base_model_fn, datasets=datasets,
                       weight_decay=weight_decay)
  if task == 'tiny64_resnet56':
    # NOTE(review): the task name says "resnet56" but the model builder is
    # ``_fc_resnet_loss_fn_tmp110``, and the dataset helper is named "32" while
    # being asked for 64x64 images -- confirm this pairing is intended.
    datasets = tinyimagenet32_datasets(batch_size=batch_size,
                                       image_size=(64, 64))
    base_model_fn = _fc_resnet_loss_fn_tmp110(
        num_classes=datasets.extra_info['num_classes'])
    return _ResNetTask(base_model_fn=base_model_fn, datasets=datasets,
                       weight_decay=weight_decay)
  raise NotImplementedError(f'Inappropriate dataset name: {task!r}')


def write_json(obj, path: str, verbose: bool=False, **kwargs) -> None:
    """Serialise ``obj`` as indented (2-space) UTF-8 JSON to ``path``.

    Args:
        obj: JSON-serialisable object to write.
        path: Destination file path; the file is created or overwritten.
        verbose: If True, print a confirmation message after writing.
        **kwargs: Extra keyword arguments forwarded to ``json.dump``.
    """
    # ``json`` is not imported at module level in this file; import it here so
    # the function does not depend on a star import happening to export it.
    import json

    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, **kwargs)
    if verbose:
        print(f"Results saved to {path}.")
        

def setup_logging(train_dir):
    """Route root-logger output at INFO level to ``<train_dir>/stdout.log``.

    Any handlers already attached to the root logger are removed and closed
    first, so calling this repeatedly does not leak open log files.

    Args:
        train_dir: Existing directory in which ``stdout.log`` is created.
    """
    # Detach *and close* previous handlers; simply resetting the list would
    # leave earlier FileHandlers' file descriptors open.
    for handler in list(logging.root.handlers):
        logging.root.removeHandler(handler)
        handler.close()
    logging_config = {
        'level': logging.INFO,
        'format': '[%(asctime)s %(filename)s: %(lineno)3d]: %(message)s',
        'datefmt': '%Y-%m-%d %H:%M:%S',
        'filename': os.path.join(train_dir, 'stdout.log'),
    }
    # basicConfig only installs a handler when the root logger has none,
    # which the loop above guarantees.
    logging.basicConfig(**logging_config)